"""
Module containing the encoders.
"""
import numpy as np

import torch
from torch import nn


# ALL encoders should be called Encoder<Model>
def get_encoder(model_type):
    """Return the encoder class matching ``model_type``.

    The lookup is case-insensitive: the name is normalised to
    ``Encoder<Model>`` (first letter upper-case, the rest lower-case) and
    resolved among the names defined at module level.

    Args:
      model_type: Name of the model as a string, e.g. ``"higgins"``.

    Returns:
      The object bound to ``Encoder<Model>`` in this module (the class itself,
      not an instance).

    Raises:
      ValueError: If this module defines no ``Encoder<Model>``.
    """
    encoder_name = "Encoder{}".format(model_type.lower().capitalize())
    # Plain dictionary lookup instead of eval(): same result for valid names,
    # but the argument is never executed as a Python expression.
    try:
        return globals()[encoder_name]
    except KeyError:
        raise ValueError(
            "Unknown encoder type {!r}: no {} defined".format(
                model_type, encoder_name)
        ) from None


class EncoderHiggins(nn.Module):
    """Convolutional encoder used in beta-VAE paper for the chairs data.

    Based on row 3 of Table 1 on page 13 of "beta-VAE: Learning Basic Visual
    Concepts with a Constrained Variational Framework"
    (https://openreview.net/forum?id=Sy2fzU9gl)

    The network is four 4x4 convolutions with stride 2 and padding 1 (each
    one halves the spatial size), followed by a 256-unit fully connected
    layer and a final linear layer with ``2 * num_latent`` outputs.

    Args:
      input_shape: Shape of a single input image in channel-first order,
        ``(num_channels, height, width)``. Only the channel count is
        mandatory; when height and width are missing, 64x64 is assumed.
      num_latent: Number of latent variables to output.
      base_channel: Number of channels of the first two convolutions; the
        last two use twice as many.

    Raises:
      ValueError: If height or width is smaller than 16, which is too small
        for four successive halvings.
    """

    def __init__(self, input_shape, num_latent, base_channel=32):
        super().__init__()
        self.num_latent = num_latent
        self.input_shape = input_shape

        # Spatial size of the input; 64x64 is the size the architecture was
        # originally hard-coded for, so it is the fallback.
        if len(input_shape) >= 3:
            height, width = int(input_shape[1]), int(input_shape[2])
        else:
            height, width = 64, 64
        if height < 16 or width < 16:
            raise ValueError(
                "EncoderHiggins needs inputs of at least 16x16, got "
                "{}x{}".format(height, width))

        # Conv2d(kernel=4, stride=2, padding=1) maps a side n to n // 2, so
        # four of them map it to n // 16 (64 -> 4).
        feat_h, feat_w = height // 16, width // 16
        wide_channel = base_channel * 2

        # Layer order is kept fixed so that state_dict keys (net.0, net.2,
        # ..., net.11) stay compatible with existing checkpoints.
        self.net = nn.Sequential(
            nn.Conv2d(input_shape[0], base_channel, (4, 4), stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(base_channel, base_channel, (4, 4), stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(base_channel, wide_channel, (4, 4), stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(wide_channel, wide_channel, (4, 4), stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(feat_h * feat_w * wide_channel, 256), nn.ReLU(),
            nn.Linear(256, num_latent * 2)
        )

    def forward(self, input_tensor):
        """Encode a batch of images.

        Args:
          input_tensor: Batch of images; the first convolution expects
            ``input_shape[0]`` channels on dimension 1.

        Returns:
          means: First ``num_latent`` columns of the network output.
          log_var: Remaining ``num_latent`` columns of the network output.
        """
        x = self.net(input_tensor)
        means, log_var = torch.split(x, [self.num_latent] * 2, 1)
        return means, log_var


class EncoderFraction(nn.Module):
    """Encoder made of several independent ``EncoderHiggins`` sub-encoders.

    The latent variables are divided evenly between ``groups`` sub-encoders.
    Every sub-encoder sees the full input and produces
    ``num_latent // groups`` means and log variances; the results are
    concatenated. Only one sub-encoder at a time has gradients enabled on its
    parameters (see ``activate``).

    Args:
      input_shape: Input shape forwarded unchanged to every sub-encoder.
      num_latent: Total number of latent variables; must be divisible by
        ``groups``.
      base_channel: ``base_channel`` forwarded to every sub-encoder.
      groups: Number of sub-encoders.

    Raises:
      ValueError: If ``num_latent`` is not divisible by ``groups``.
    """

    def __init__(self, input_shape, num_latent,
                 base_channel=8,
                 groups=4,
                 ):
        super().__init__()
        # Explicit check instead of assert, which is stripped under -O.
        if num_latent % groups != 0:
            raise ValueError(
                "num_latent ({}) must be divisible by groups ({})".format(
                    num_latent, groups))
        self.num_latent = num_latent
        self.input_shape = input_shape
        self.groups = groups
        self.active = 0
        # ModuleList registers the sub-encoders under the same parameter
        # names ("convs.<i>. ...") as a Sequential would, without implying
        # that they are chained.
        self.convs = nn.ModuleList(
            EncoderHiggins(input_shape, num_latent // groups, base_channel)
            for _ in range(groups)
        )
        self.activate(0)

    def activate(self, group):
        """Enable gradients for sub-encoder ``group`` and disable all others.

        Args:
          group: Index of the sub-encoder to train. An index that matches no
            sub-encoder leaves every sub-encoder with gradients disabled.
        """
        self.active = group
        for index, encoder in enumerate(self.convs):
            encoder.requires_grad_(index == group)

    def forward(self, input_tensor):
        """Run every sub-encoder on the input and concatenate their outputs.

        Args:
          input_tensor: Batch passed unchanged to each sub-encoder.

        Returns:
          Tuple ``(means, log_var)``, each the concatenation along dimension
          1 of the corresponding sub-encoder outputs, in group order.
        """
        mean_list, log_var_list = [], []
        for encoder in self.convs:
            means, log_var = encoder(input_tensor)
            mean_list.append(means)
            log_var_list.append(log_var)

        return torch.cat(mean_list, 1), torch.cat(log_var_list, 1)


class EncoderBurgess(nn.Module):
    def __init__(self, img_size,
                 latent_dim=10,
                 hid_channels=32):
        r"""Encoder of the model proposed in [1].

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).

        latent_dim : int
            Dimensionality of latent output.

        hid_channels : int
            Number of output channels of every convolutional layer.

        Model Architecture (transposed for decoder)
        ------------
        - 3 convolutional layers (each with `hid_channels` channels),
          (4 x 4 kernel), (stride of 2); a 4th identical layer is added when
          the image is 64 x 64
        - 2 fully connected layers (each of 256 units)
        - Latent distribution:
            - 1 fully connected layer of `2 * latent_dim` units (log variance
              and mean for `latent_dim` Gaussians)

        References:
            [1] Burgess, Christopher P., et al. "Understanding disentangling in
            $\beta$-VAE." arXiv preprint arXiv:1804.03599 (2018).
        """
        super(EncoderBurgess, self).__init__()

        # Layer parameters
        kernel_size = 4
        hidden_dim = 256
        self.latent_dim = latent_dim
        self.img_size = img_size
        # Shape required to start transpose convs
        self.reshape = (hid_channels, kernel_size, kernel_size)
        n_chan = self.img_size[0]

        # Convolutional layers
        cnn_kwargs = dict(stride=2, padding=1)
        self.conv1 = nn.Conv2d(n_chan, hid_channels, kernel_size, **cnn_kwargs)
        self.conv2 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)
        self.conv3 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)

        # If input image is 64x64 do fourth convolution
        if self.img_size[1] == self.img_size[2] == 64:
            self.conv_64 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)

        # Fully connected layers
        # np.prod rather than np.product: the latter alias was removed in
        # NumPy 2.0.
        self.lin1 = nn.Linear(int(np.prod(self.reshape)), hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, hidden_dim)

        # Fully connected layers for mean and variance
        self.mu_logvar_gen = nn.Linear(hidden_dim, self.latent_dim * 2)

    def forward(self, x):
        """Encode a batch of images into latent means and log variances.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images; dimension 0 is the batch dimension.

        Returns
        -------
        mu, logvar : torch.Tensor
            Two tensors of shape (batch_size, latent_dim).
        """
        batch_size = x.size(0)

        # Convolutional layers with ReLu activations
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        if self.img_size[1] == self.img_size[2] == 64:
            x = torch.relu(self.conv_64(x))

        # Fully connected layers with ReLu activations
        x = x.view((batch_size, -1))
        x = torch.relu(self.lin1(x))
        x = torch.relu(self.lin2(x))

        # Fully connected layer for log variance and mean
        # Log std-dev in paper (bear in mind)
        mu_logvar = self.mu_logvar_gen(x)
        # The output is read as interleaved (mu, logvar) pairs: even columns
        # are means, odd columns are log variances.
        mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)

        return mu, logvar


class EncoderLocatello(nn.Module):
    def __init__(self, img_size,
                 latent_dim=10,
                 hid_channels=32):
        r"""Encoder of the model proposed in [1].

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).

        latent_dim : int
            Dimensionality of latent output.

        hid_channels : int
            Number of output channels of the first two convolutional layers.

        Model Architecture (transposed for decoder)
        ------------
        - 2 convolutional layers with `hid_channels` channels followed by
          1 convolutional layer with 64 channels, (4 x 4 kernel),
          (stride of 2); a further 64-channel layer is added when the image
          is 64 x 64
        - 1 fully connected layer of 256 units
        - Latent distribution:
            - 2 separate fully connected layers of `latent_dim` units each,
              one for the mean and one for the log variance

        References:
            [1] Locatello, Challenging common assumptions in the unsupervised learning of disentangled representations(2019).
        """
        super(EncoderLocatello, self).__init__()

        # Layer parameters
        kernel_size = 4
        hidden_dim = 256
        self.latent_dim = latent_dim
        self.img_size = img_size
        # Shape required to start transpose convs
        self.reshape = (64, 4, 4)
        n_chan = self.img_size[0]

        # Convolutional layers
        cnn_kwargs = dict(stride=2, padding=1)
        self.conv1 = nn.Conv2d(n_chan, hid_channels, kernel_size, **cnn_kwargs)
        self.conv2 = nn.Conv2d(hid_channels, hid_channels, kernel_size, **cnn_kwargs)

        self.conv3 = nn.Conv2d(hid_channels, 64, kernel_size, **cnn_kwargs)

        # If input image is 64x64 do fourth convolution
        if self.img_size[1] == self.img_size[2] == 64:
            self.conv_64 = nn.Conv2d(64, 64, kernel_size, **cnn_kwargs)

        # Fully connected layers
        # np.prod rather than np.product: the latter alias was removed in
        # NumPy 2.0.
        self.lin1 = nn.Linear(int(np.prod(self.reshape)), hidden_dim)

        # Separate heads for mean and log variance
        self.mu_gen = nn.Linear(hidden_dim, self.latent_dim)
        self.logvar_gen = nn.Linear(hidden_dim, self.latent_dim)

    def forward(self, x):
        """Encode a batch of images into latent means and log variances.

        Parameters
        ----------
        x : torch.Tensor
            Batch of images; dimension 0 is the batch dimension.

        Returns
        -------
        mu, logvar : torch.Tensor
            Outputs of the `mu_gen` and `logvar_gen` heads, each with
            `latent_dim` columns.
        """
        batch_size = x.size(0)
        # Convolutional layers with ReLu activations
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        if self.img_size[1] == self.img_size[2] == 64:
            x = torch.relu(self.conv_64(x))

        # Fully connected layers with ReLu activations
        x = x.view((batch_size, -1))
        x = torch.relu(self.lin1(x))

        # Fully connected layer for log variance and mean
        # Log std-dev in paper (bear in mind)
        mu = self.mu_gen(x)
        logvar = self.logvar_gen(x)

        return mu, logvar
